{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "# For tips on running notebooks in Google Colab, see\n",
    "# https://pytorch.org/tutorials/beginner/colab\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Save and Load the Model\n",
    "=======================\n",
    "\n",
    "In this section we will look at how to persist model state by saving it,\n",
    "loading it back, and running model predictions.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision.models as models"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Saving and Loading Model Weights\n",
    "================================\n",
    "\n",
    "PyTorch models store their learned parameters in an internal state\n",
    "dictionary called `state_dict`. This dictionary can be persisted with the\n",
    "`torch.save` function:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "model = models.vgg16(weights='IMAGENET1K_V1')\n",
    "torch.save(model.state_dict(), 'model_weights.pth')"
   ]
  },
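  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a rough illustration of what a `state_dict` holds, the sketch below\n",
    "builds a small stand-in network and prints the name and shape of every\n",
    "entry. The variable `tiny` and its layer sizes are made up for this\n",
    "example and are not part of the VGG16 workflow above.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "# Hypothetical small network, used only to keep the printout short\n",
    "tiny = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))\n",
    "\n",
    "# state_dict maps parameter (and buffer) names to tensors;\n",
    "# ReLU has no parameters, so index 1 does not appear\n",
    "for name, tensor in tiny.state_dict().items():\n",
    "    print(name, tuple(tensor.shape))"
   ]
  },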
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To load model weights, you first need to create an instance of the same\n",
    "model, and then load the parameters using the `load_state_dict()` method.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
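    "model = models.vgg16() # we do not specify ``weights``, i.e. create untrained model\n",
    "# weights_only=True restricts unpickling to tensors, primitive types and dictionaries\n",
    "model.load_state_dict(torch.load('model_weights.pth', weights_only=True))\n",
    "model.eval()"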
   ]
  },
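  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One way to check that a save and load round trip restored the weights is\n",
    "to compare the outputs of the original and the reloaded model on the same\n",
    "input. The following is a minimal sketch under stated assumptions: the\n",
    "helper `make_model`, the small architecture, and the temporary file name\n",
    "are all hypothetical and chosen only so the cell runs quickly.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "\n",
    "def make_model():\n",
    "    # Hypothetical helper: every call builds the same architecture\n",
    "    # with freshly initialized (different) random weights\n",
    "    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))\n",
    "\n",
    "\n",
    "original = make_model()\n",
    "path = os.path.join(tempfile.mkdtemp(), 'tiny_weights.pth')\n",
    "torch.save(original.state_dict(), path)\n",
    "\n",
    "# A new instance starts with different weights until load_state_dict runs\n",
    "restored = make_model()\n",
    "restored.load_state_dict(torch.load(path, weights_only=True))\n",
    "\n",
    "original.eval()\n",
    "restored.eval()\n",
    "\n",
    "x = torch.randn(3, 4)\n",
    "with torch.no_grad():\n",
    "    same = torch.allclose(original(x), restored(x))\n",
    "print(same)  # True: both models now hold identical parameters"
   ]
  },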
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"background-color: #54c7ec; color: #fff; font-weight: 700; padding-left: 10px; padding-top: 5px; padding-bottom: 5px\"><strong>NOTE:</strong></div>\n",
    "<div style=\"background-color: #f3f4f7; padding-left: 10px; padding-top: 10px; padding-bottom: 10px; padding-right: 10px\">\n",
    "<p>Be sure to call the <code>model.eval()</code> method before running inference to set the dropout and batch normalization layers to evaluation mode. Failing to do this will yield inconsistent inference results.</p>\n",
    "</div>\n"
   ]
  },
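  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why the mode matters, the sketch below applies a single\n",
    "`nn.Dropout` layer to the same input in training mode and in evaluation\n",
    "mode. The layer, the probability `p=0.5`, and the input are illustrative\n",
    "assumptions and are not taken from the VGG16 model above.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "layer = nn.Dropout(p=0.5)  # hypothetical stand-alone layer\n",
    "x = torch.ones(1, 8)\n",
    "\n",
    "# Training mode: elements are zeroed at random and the survivors are\n",
    "# scaled by 1 / (1 - p), so repeated calls can give different results\n",
    "layer.train()\n",
    "print(layer(x))\n",
    "print(layer(x))\n",
    "\n",
    "# Evaluation mode: dropout is the identity, so the output is deterministic\n",
    "layer.eval()\n",
    "print(torch.equal(layer(x), x))  # True"
   ]
  },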
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Saving and Loading Models with Shapes\n",
    "=====================================\n",
    "\n",
    "When loading model weights, we needed to instantiate the model class\n",
    "first, because the class defines the structure of a network. We might\n",
    "want to save the structure of this class together with the model, in\n",
    "which case we can pass `model` (and not `model.state_dict()`) to the\n",
    "saving function:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "torch.save(model, 'model.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can then load the model like this:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
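    "# Pass weights_only=False explicitly: recent PyTorch releases default to True,\n",
    "# which rejects a fully pickled model. Only do this for files you trust.\n",
    "model = torch.load('model.pth', weights_only=False)"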
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div style=\"background-color: #54c7ec; color: #fff; font-weight: 700; padding-left: 10px; padding-top: 5px; padding-bottom: 5px\"><strong>NOTE:</strong></div>\n",
    "<div style=\"background-color: #f3f4f7; padding-left: 10px; padding-top: 10px; padding-bottom: 10px; padding-right: 10px\">\n",
    "<p>This approach uses the Python <a href=\"https://docs.python.org/3/library/pickle.html\">pickle</a> module when serializing the model, so it relies on the actual class definition being available when the model is loaded.</p>\n",
    "</div>\n"
   ]
  },
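  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The whole-model approach can be tried end to end on a small network, as\n",
    "in the sketch below. The model `tiny` and the temporary file name are\n",
    "hypothetical. Because `nn.Sequential` and `nn.Linear` are importable from\n",
    "`torch`, pickle can find their class definitions at load time; a custom\n",
    "class would need to be importable in the loading process as well.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import tempfile\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "tiny = nn.Sequential(nn.Linear(4, 2))  # hypothetical stand-in model\n",
    "path = os.path.join(tempfile.mkdtemp(), 'tiny_model.pth')\n",
    "\n",
    "# Saves the pickled module object, not just its state_dict\n",
    "torch.save(tiny, path)\n",
    "\n",
    "# weights_only=False allows arbitrary unpickling; use it only on trusted files\n",
    "loaded = torch.load(path, weights_only=False)\n",
    "\n",
    "print(type(loaded).__name__)  # Sequential\n",
    "print(torch.equal(loaded[0].weight, tiny[0].weight))  # True"
   ]
  },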
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Related Tutorials\n",
    "=================\n",
    "\n",
    "-   [Saving and Loading a General Checkpoint in\n",
    "    PyTorch](https://pytorch.org/tutorials/recipes/recipes/saving_and_loading_a_general_checkpoint.html)\n",
    "-   [Tips for loading an nn.Module from a\n",
    "    checkpoint](https://pytorch.org/tutorials/recipes/recipes/module_load_state_dict_tips.html?highlight=loading%20nn%20module%20from%20checkpoint)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
